import os
import pprint
from dataclasses import asdict
import warnings
warnings.filterwarnings('ignore')

import bullet_safety_gym

try:
    import safety_gymnasium
except ImportError:
    print("safety_gymnasium is not found.")
# import gymnasium as gym
from gym_minigrid.register import env_list
from gym_minigrid.minigrid import Grid, OBJECT_TO_IDX
from gym_minigrid.wrappers import *
import gym
import pyrallis
import torch
import torch.nn as nn
from tianshou.data import VectorReplayBuffer
from tianshou.env import BaseVectorEnv, ShmemVectorEnv, SubprocVectorEnv
from fsrl.utils.exp_util import load_config_and_model, seed_all
from tianshou.utils.net.common import Net
# from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.net.discrete import Actor
from torch.distributions import Independent, Normal, Categorical

from fsrl.config.cvpo_disc_cfg import (
    Bullet1MCfg,
    Bullet5MCfg,
    Bullet10MCfg,
    Mujoco2MCfg,
    Mujoco5MCfg,
    Mujoco10MCfg,
    Mujoco20MCfg,
    MujocoBaseCfg,
    TrainCfg,
)
from fsrl.data import FastCollector, FastCollectorSimp
from fsrl.policy import CVPO_Disc
from fsrl.trainer import OffpolicyTrainer
from fsrl.utils import TensorboardLogger, WandbLogger
from fsrl.utils.exp_util import auto_name, seed_all
from fsrl.utils.net.common import ActorCritic, CNNActor, BudgetaryShieldGT, BudgetaryShieldLLM, BudgetaryShieldLLMNew, BudgetaryShieldLLMNewDeepSeek
from fsrl.utils.net.continuous import DoubleCritic, SingleCritic, DoubleCriticCNN, SingleCriticCNN, EnsembleCostModel
from fsrl.utils.net.discrete import EnsembleActor

# Per-task default config classes, keyed by index into gym_minigrid's
# ``env_list``. Every task index 0..15 currently uses the generic TrainCfg.
TASK_TO_CFG = {task_idx: TrainCfg for task_idx in range(16)}


def _build_critics(args, state_shape, action_shape):
    """Build the two critic modules passed to ``CVPO_Disc``.

    Args:
        args: training config; reads ``hidden_sizes``, ``device`` and
            ``double_critic``.
        state_shape: observation shape; only ``state_shape[0]`` is used.
        action_shape: number of discrete actions.

    Returns:
        A list of two critics moved to ``args.device``. Each one is a
        ``DoubleCriticCNN`` wrapping two ``Net`` instances when
        ``args.double_critic`` is set, otherwise a ``SingleCriticCNN``.
    """
    # NOTE(review): the critic input width is state_shape[0] - 25 + 12. The
    # meaning of these constants is not visible here (presumably the critic
    # consumes a re-encoded observation) -- confirm against *CriticCNN.
    critic_in_dim = state_shape[0] - 25 + 12

    def make_net():
        # One Q-network body; concat=True means (obs, action) are concatenated.
        return Net(
            critic_in_dim,
            action_shape,
            hidden_sizes=args.hidden_sizes,
            concat=True,
            device=args.device,
        )

    critics = []
    for _ in range(2):
        if args.double_critic:
            # Arguments are evaluated left to right, so construction order
            # (net1, then net2) matches the original code.
            critic = DoubleCriticCNN(make_net(), make_net(), device=args.device)
        else:
            critic = SingleCriticCNN(make_net(), device=args.device)
        critics.append(critic.to(args.device))
    return critics


def _build_shield(args, env):
    """Return the budgetary shield selected by ``args.llm_shield``/``use_deepseek``.

    Every shield is built from the wrapped env's ``avoid_obj`` and ``hc``
    attributes. The LLM-based shields also receive ``args.device``.
    """
    if not args.llm_shield:
        return BudgetaryShieldGT(env.avoid_obj, env.hc)
    if args.use_deepseek:
        return BudgetaryShieldLLMNewDeepSeek(env.avoid_obj, env.hc, device=args.device)
    return BudgetaryShieldLLMNew(env.avoid_obj, env.hc, device=args.device)


@pyrallis.wrap()
def train(args: TrainCfg):
    """Train a CVPO_Disc agent with a budgetary shield on a gym_minigrid task.

    ``args`` is supplied by ``pyrallis.wrap``. The function sets up logging,
    the vectorized train and test envs, the actor and critics, the shield and
    shield-distillation models, and the policy, then runs ``OffpolicyTrainer``.
    When this module runs as a script, it also prints a final 10-episode
    evaluation in eval mode and in train mode.
    """
    # set seed and computing
    seed_all(int(args.seed))
    torch.set_num_threads(args.thread)

    task = int(args.task)
    env_name = env_list[task]
    default_cfg = TASK_TO_CFG.get(task, TrainCfg)()

    # setup logger
    cfg = asdict(args)
    default_cfg = asdict(default_cfg)
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
    if args.group is None:
        args.group = env_name + "-cost-" + str(int(args.cost_limit))
    if args.logdir is not None:
        args.logdir = os.path.join(args.logdir, args.project, args.group)
    # logger = WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
    logger = TensorboardLogger(args.logdir, log_txt=True, name=args.name)
    logger.save_config(cfg, verbose=args.verbose)

    if args.pretrain_path is not None and args.pretrain_load:
        cfg, model = load_config_and_model(args.pretrain_path, False, device=args.device)

    training_num = min(args.training_num, args.episode_per_collect)
    # NOTE(review): eval() executes the config string as code. That is
    # acceptable for a trusted local config, but a name->class mapping
    # (e.g. {"SubprocVectorEnv": SubprocVectorEnv, ...}) would be safer.
    worker = eval(args.worker)

    def make_env():
        return StandardSafetyGymWrapper(gym.make(env_name), int_reward=args.int_reward)

    train_envs = worker([make_env for _ in range(training_num)])
    test_envs = worker([make_env for _ in range(args.testing_num)])

    # model: a single env instance used to read spaces and shield attributes
    env = make_env()
    state_shape = env.observation_space.shape or env.observation_space.n
    action_shape = env.action_space.n

    assert hasattr(
        env.spec, "max_episode_steps"
    ), "Please use an env wrapper to provide 'max_episode_steps' for CVPO"

    actor = CNNActor(
        state_shape[0],
        action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
        softmax_output=False,
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    critics = _build_critics(args, state_shape, action_shape)
    critic_optim = torch.optim.Adam(
        nn.ModuleList(critics).parameters(), lr=args.critic_lr
    )

    actor_critic = ActorCritic(actor, critics)
    # orthogonal initialization of every Linear layer in actor and critics
    for m in actor_critic.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            torch.nn.init.zeros_(m.bias)

    if args.last_layer_scale:
        # Scale down the policy's last layer so the initial outputs are close
        # to zero (assumes actor.mu produces the logits -- confirm in CNNActor),
        # see https://arxiv.org/abs/2006.05990, Fig.24 for details
        for m in actor.mu.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.zeros_(m.bias)
                m.weight.data.copy_(0.01 * m.weight.data)

    def dist(logits):
        return Categorical(logits=logits)

    shield = _build_shield(args, env)

    # shield distill for llm shield
    cost_shield_distill = EnsembleCostModel(state_shape[0], action_shape, args.hidden_sizes, device=args.device)
    actor_shield_distill = EnsembleActor(state_shape[0], action_shape, args.hidden_sizes, softmax_output=True, device=args.device)

    cost_shield_optim = torch.optim.Adam(cost_shield_distill.parameters(), lr=args.distill_lr)
    actor_shield_optim = torch.optim.Adam(actor_shield_distill.parameters(), lr=args.distill_lr)

    policy = CVPO_Disc(
        actor=actor,
        critics=critics,
        actor_optim=actor_optim,
        critic_optim=critic_optim,
        logger=logger,
        action_space=env.action_space,
        dist_fn=dist,
        max_episode_steps=args.max_episode_steps,
        cost_limit=args.cost_limit,
        tau=args.tau,
        gamma=args.gamma,
        n_step=args.n_step,
        # E-step
        estep_iter_num=args.estep_iter_num,
        estep_kl=args.estep_kl,
        estep_dual_max=args.estep_dual_max,
        estep_dual_lr=args.estep_dual_lr,
        sample_act_num=args.sample_act_num,  # for continous action space
        # M-step
        mstep_iter_num=args.mstep_iter_num,
        mstep_kl=args.mstep_kl,
        mstep_dual_max=args.mstep_dual_max,
        mstep_dual_lr=args.mstep_dual_lr,
        deterministic_eval=args.deterministic_eval,
        action_scaling=args.action_scaling,
        action_bound_method=args.action_bound_method,
        lr_scheduler=None,
        disc_sample=args.disc_sample,
        cost_shield_distill=cost_shield_distill,
        actor_shield_distill=actor_shield_distill,
        cost_shield_optim=cost_shield_optim,
        actor_shield_optim=actor_shield_optim
    )
    if args.pretrain_path is not None and args.pretrain_load:
        policy.load_state_dict(model["model"])

    # collector
    train_collector = FastCollector(
        policy,
        train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)),
        exploration_noise=False,
        shield=shield,
        use_shield=args.use_shield,
        use_shield_distill=args.use_shield_distill,
        use_fake_data=args.use_fake_data,
        fake_next_obs=args.fake_next_obs,
        fake_done=args.fake_done,
        fake_rew=args.fake_rew,
        fake_cost=args.fake_cost,
        fake_weight=args.fake_weight,
        shield_prop=args.shield_prop,
        cost_shield_thres=args.cost_shield_thres,
        actor_shield_thres=args.actor_shield_thres,
        action_dim=action_shape,
        prepare_steps=args.prepare_steps,
    )
    test_collector = FastCollector(policy, test_envs)

    def stop_fn(reward, cost):
        return reward > args.reward_threshold and cost < args.cost_limit

    def checkpoint_fn():
        # The distillation models are always constructed above, so the
        # checkpoint always includes them alongside the policy weights.
        return {
            "model": policy.state_dict(),
            "cost_shield": cost_shield_distill.state_dict(),
            "actor_shield": actor_shield_distill.state_dict(),
        }

    if args.save_ckpt:
        logger.setup_checkpoint_fn(checkpoint_fn)

    # trainer
    trainer = OffpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=args.epoch,
        batch_size=args.batch_size,
        cost_limit=args.cost_limit,
        step_per_epoch=args.step_per_epoch,
        update_per_step=args.update_per_step,
        episode_per_test=args.testing_num,
        episode_per_collect=args.episode_per_collect,
        stop_fn=stop_fn,
        logger=logger,
        resume_from_log=args.resume,
        save_model_interval=args.save_interval,
        verbose=args.verbose,
    )

    # Bound before the loop so the final report does not hit NameError when
    # the trainer yields no epochs (e.g. epoch == 0).
    info = None
    for epoch, _epoch_stat, info in trainer:
        logger.store(tab="train", cost_limit=args.cost_limit)
        print(f"Epoch: {epoch}")
        print(info)

    if __name__ == "__main__":
        pprint.pprint(info)
        # Let's watch its performance! Reuse the already-resolved env_name and
        # configure the env exactly like the train/test envs (same int_reward).
        env = make_env()
        policy.eval()
        collector = FastCollector(policy, env)
        result = collector.collect(n_episode=10, render=args.render)
        rews, lens, cost = result["rew"], result["len"], result["cost"]
        print(f"Final eval reward: {rews.mean()}, cost: {cost}, length: {lens.mean()}")

        policy.train()
        collector = FastCollector(policy, env)
        result = collector.collect(n_episode=10, render=args.render)
        rews, lens, cost = result["rew"], result["len"], result["cost"]
        print(f"Final train reward: {rews.mean()}, cost: {cost}, length: {lens.mean()}")


if __name__ == "__main__":
    # Script entry point. train() is called with no explicit arguments because
    # its @pyrallis.wrap() decorator builds the TrainCfg (presumably from CLI
    # flags / a config file) and passes it in.
    train()
